{
 "metadata": {
  "name": "",
  "signature": "sha256:c542ab20c992e59a7fd483f020b9895c56365523c2879300f240096b19865944"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn import datasets\n",
      "import matplotlib.pyplot as plt\n",
      "import random\n",
      "import numpy as np\n",
      "%matplotlib inline"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "iris = datasets.load_iris()\n",
      "print (sum(iris.target==1))\n",
      "print (len(iris.target))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "50\n",
        "150\n"
       ]
      }
     ],
     "prompt_number": 2
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "data_size = sum(iris.target==0) + sum(iris.target==1)\n",
      "index = [i for i in range(data_size)]\n",
      "random.shuffle(index)\n",
      "iris.data = iris.data[index]\n",
      "iris.target = iris.target[index]\n",
      "\n",
      "#split data\n",
      "test_size = 20\n",
      "x_train = iris.data[test_size:]\n",
      "x_test = iris.data[:test_size]\n",
      "y_train = iris.target[test_size:]\n",
      "y_test = iris.target[:test_size]\n",
      "\n",
      "print (x_train.shape)\n",
      "print (y_test.shape)\n",
      "print (x_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "(80, 4)\n",
        "(20,)\n",
        "[[ 5.4  3.4  1.5  0.4]\n",
        " [ 5.4  3.4  1.7  0.2]\n",
        " [ 5.5  2.5  4.   1.3]\n",
        " [ 5.8  2.6  4.   1.2]\n",
        " [ 4.9  2.4  3.3  1. ]\n",
        " [ 5.6  3.   4.1  1.3]\n",
        " [ 5.   3.2  1.2  0.2]\n",
        " [ 4.8  3.4  1.9  0.2]\n",
        " [ 5.5  2.6  4.4  1.2]\n",
        " [ 5.7  2.8  4.5  1.3]\n",
        " [ 5.7  2.6  3.5  1. ]\n",
        " [ 4.5  2.3  1.3  0.3]\n",
        " [ 4.6  3.6  1.   0.2]\n",
        " [ 6.7  3.1  4.4  1.4]\n",
        " [ 4.7  3.2  1.3  0.2]\n",
        " [ 5.1  3.8  1.9  0.4]\n",
        " [ 5.2  3.4  1.4  0.2]\n",
        " [ 5.4  3.9  1.3  0.4]\n",
        " [ 4.9  3.1  1.5  0.1]\n",
        " [ 5.   2.3  3.3  1. ]]\n"
       ]
      }
     ],
     "prompt_number": 3
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def sigmoid(x):\n",
      "    return 1.0/(1.0+np.exp(-x))\n",
      "\n",
      "def cost(xMat, yMat, w):\n",
      "    left = np.multiply(yMat, np.log(sigmoid(xMat*w)))\n",
      "    right = np.multiply(1-yMat, np.log(1-sigmoid(xMat*w)))\n",
      "    return np.sum(left+right)/(-len(xMat))\n",
      "\n",
      "def train(xdata, ylabel, lr = 0.001, epoches = 10000):\n",
      "    xMat = np.mat(xdata)\n",
      "    yMat = np.mat(ylabel).T\n",
      "    \n",
      "    costList = []\n",
      "    m,n = np.shape(xMat)\n",
      "    w = np.mat(np.ones((n,1)))\n",
      "    b = np.mat(1.0)\n",
      "    for i in range(epoches):\n",
      "        a = sigmoid(xMat*w + b)\n",
      "        w_grad = np.dot(xMat.T, (a-yMat))/m\n",
      "        w -= np.dot(lr,w_grad)\n",
      "        b_grad = np.sum((a-yMat))/m\n",
      "        b -= lr*b_grad\n",
      "        if i % 50 == 0:\n",
      "            costList.append(cost(xMat, yMat, w))\n",
      "    return w, b, costList\n",
      "    \n",
      "    \n",
      "    "
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 4
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "w, b, costList = train(x_train, y_train)\n",
      "print (w)\n",
      "print (b)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[[-0.68322684]\n",
        " [-0.74911581]\n",
        " [ 1.59736318]\n",
        " [ 1.37068057]]\n",
        "[[ 0.57062438]]\n"
       ]
      }
     ],
     "prompt_number": 5
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "x = np.linspace(0,10000,200)\n",
      "print (x.shape)\n",
      "print (len(costList))\n",
      "plt.plot(x, costList, c='r')\n",
      "plt.title(\"Train\")\n",
      "plt.xlabel('Epochs')\n",
      "plt.ylabel('Cost')\n",
      "plt.show()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "(200,)\n",
        "200\n"
       ]
      },
      {
       "metadata": {},
       "output_type": "display_data",
       "png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEWCAYAAABliCz2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAFzxJREFUeJzt3XuwXWd53/HvoyPJkiVZt3OkSDZw\nTCDEThp8EYwJkBK3IYYwYdLSYAKJm0AdaNqayyTBw7QzmTLTNqQe4oYkuMQtDbZx6uCEOA4OF7c0\nU1f4qBjjCy7GOPJdkm1ZkhXdn/6x1pa2js5lS9prr3P2+/3MrFnXvde7zpJ+693vukVmIkkafgva\nLoAkaTAMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj40imKiJGI2BMRL227LFIvwuvwVYqI2NM1eiaw\nHzhcj/9qZt4w+FJJg2Pgq0gR8Sjwvsz8ygzLLMzMQ4MrldQsm3SkWkR8PCJujoibImI38J6IeF1E\n/J+I2BkRT0XEtRGxqF5+YURkRIzX45+r5/9VROyOiLsi4twWN0k6joEvHe/ngBuBlcDNwCHgKmAU\neD1wGfCrM3z+F4B/DawBtgL/tsnCSifDwJeO9zeZ+ReZeSQz/y4z787MzZl5KDMfAa4D/v4Mn78l\nMycy8yBwA3DBQEot9WBh2wWQ5pjHukci4oeB/whcTHWidyGweYbPP901vBdY3u8CSqfKGr50vMlX\nMXwauA94RWaeBfwbIAZeKqkPDHxpZiuAF4AXI+I8Zm6/l+Y0A1+a2UeAK4DdVLX9m9stjnTqvA5f\nkgphDV+SCmHgS1IhDHxJKoSBL0mFmFM3Xo2Ojub4+HjbxZCkeWPLli07MnOsl2XnVOCPj48zMTHR\ndjEkad6IiL/tdVmbdCSpEAa+JBXCwJekQjTahl+/VWg31WvkDmXmpibXJ0ma3iBO2v5kZu4YwHok\nSTOwSUeSCtF04Cfw1xGxJSKunGqBiLgyIiYiYmL79u0NF0eSytV04L8hMy8C3gL8WkT8xOQFMvO6\nzNyUmZvGxnq6d2DyF8DHPw533HH6pZWkIdZo4GfmE3V/G3Ar8Nq+ryQCPvEJuP32vn+1JA2TxgI/\nIpZFxIrOMPBmqlfF9d/YGNgcJEkzavIqnfXArRHRWc+NmfmlRtZk4EvSrBoL/Mx8BHh1U99/nLEx\n2Lp1IKuSpPlqOC7LHB21hi9JsxiOwO806fh+Xkma1vAE/sGDsGtX2yWRpDlreAIfbNaRpBkY+JJU\nCANfkgph4EtSIYYr8Hf4FGZJms5wBP6yZbB0qTV8SZrBcAQ++HgFSZqFgS9JhTDwJakQBr4kFcLA\nl6RCDFfg791bdZKkEwxX4IO1fEmaxvAE/uho1TfwJWlKwxP41vAlaUYGviQVwsCXpEIMT+CvXAmL\nFhn4kjSN4Qn8iOrErU/MlKQpDU/ggzdfSdIMDHxJKoSBL0mFMPAlqRDDF/gvvAAHDrRdEkmac4Yv\n8MErdSRpCsMZ+DbrSNIJDHxJKoSBL0mFaDzwI2IkIr4ZEbc1vS4fkSxJ0xtEDf8q4MEBrAfWrKke\nsWDgS9IJGg38iDgH+BngM02u56iREVi71sCXpCk0XcP/JPAbwJHpFoiIKyNiIiImtvcjqL35SpKm\n1FjgR8TbgG2ZuWWm5TLzuszclJmbxjonXU/H2JjX4UvSFJqs4b8e+NmIeBT4PHBpRHyuwfVVrOFL\n0pQaC/zMvDozz8nMceBy4GuZ+Z6m1neUgS9JUxqu6/ChCvxnn4XDh9suiSTNKQMJ/Mz8H5n5tkGs\ni7ExyITnnhvI6iRpvhjOGj7YrCNJkxj4klQIA1+SCmHgS1Ihhi/wfYCaJE1p+AJ/0SJYtcrAl6RJ\nhi/wwZuvJGkKwxn4o6MGviRNMpyBbw1fkk4wvIHvEzMl6TjDHfiZbZdEkuaM4Q38gwfhhRfaLokk\nzRnDG/hgO74kdTHwJakQBr4kFcLAl6RCGPiSVIjhDPylS2HZMgNfkroMZ+CDd9tK0iQGviQVwsCX\npEIY+JJUiOEPfJ+nI0nAMAf+6Cjs2wd797ZdEkmaE4Y38L0WX5KOY+BLUiEMfEkqhIEvSYUw8CWp\nEMMb+CtWwOLFBr4k1YY38CO8+UqSugxv4IOBL0ldGgv8iFgSEd+IiG9FxP0R8VtNrWtaBr4kHdVk\nDX8/cGlmvhq4ALgsIi5pcH0nMvAl6aiFTX1xZiawpx5dVHeDfbCNgS9JRzXahh8RIxFxD7AN+HJm\nbp5imSsjYiIiJrb3O5zHxmD3bti/v7/fK0nzUKOBn5mHM/MC4BzgtRHxo1Msc11mbsrMTWOda+f7\nxWvxJemogVylk5k7gTuBywaxvqM6gb9jx0BXK0lzUZNX6YxFxKp6eCnwU8B3mlrflEZHq741fElq\n7qQtsAH4bESMUB1Y/iQzb2twfSeySUeSjmryKp17gQub+v6eGPiSdNRw32m7ejWMjBj4ksSwB/6C\nBbB2rYEvSQx74IM3X0lSzcCXpEIY+JJUCANfkgrRU+BHxB/3Mm1OGhuD556DQ4faLokktarXGv6P\ndI/UN1Nd3P/iNKBzLf6zz7ZbDklq2YyBHxFXR8Ru4MciYlfd7aZ6+uWfD6SEp8ubryQJmCXwM/Pf\nZeYK4BOZeVbdrcjMtZl59YDKeHoMfEkCem/SuS0ilgFExHsi4pqIeFmD5eofn5gpSUDvgf8HwN6I\neDXwEeB7wH9rrFT9ZA1fkoDeA/9Q/crCtwO/l5mfAlY0V6w+Wru26hv4kgrX69Myd0fE1cAvAm+M\niAVU76id+xYurB6iZuBLKlyvNfx3AvuBX8nMp6leWfiJxkrVb958JUm9BX4d8jcAKyPibcC+zJwf\nbfhg4EsSvd9p+/PAN4B/Avw8sDki3tFkwfrKwJekntvwPwa8JjO3QfW+WuArwC1NFayvxsbgrrva\nLoUktarXNvwFnbCvPXsSn23f2Fh1Hf6RI22XRJJa02sN/0sRcQdwUz3+TuD2ZorUgLExOHwYdu6E\nNWvaLo0ktWLGwI+IVwDrM/PXI+IfAW+oZ91FdRJ3fui++crAl1So2ZplPgnsAsjML2TmhzPzw8Ct\n9bz5wbttJWnWwF+fmd+ePLGeNt5IiZpg4EvSrIG/aoZ5S/tZkEYZ+JI0a+BPRMQ/mzwxIt4HbGmm\nSA3wiZmSNOtVOh8Ebo2Id3Ms4DcBi4Gfa7JgfXXGGbBihTV8SUWbMfAz8xngxyPiJ4EfrSf/ZWZ+\nrfGS9Zt320oqXE/X4WfmncCdDZelWQa+pMLNn7tlT5eBL6lw5QT+6KiBL6lo5QR+p4af2XZJJKkV\njQV+RLwkIu6MiAci4v6IuKqpdfVkbAz274c9e1othiS1pcka/iHgI5l5PnAJ8GsRcX6D65uZN19J\nKlxjgZ+ZT2Xm/62HdwMPAmc3tb5ZGfiSCjeQNvyIGAcuBDZPMe/KiJiIiIntTYaxgS+pcI0HfkQs\nB/4U+GBm7po8PzOvy8xNmblprBPKTTDwJRWu0cCPiEVUYX9DZn6hyXXNysCXVLgmr9IJ4I+ABzPz\nmqbW07Nly2DJEgNfUrGarOG/HvhF4NKIuKfu3trg+mYWcezdtpJUoF7faXvSMvNvgGjq+0+Jj1eQ\nVLBy7rQFA19S0Qx8SSqEgS9JhSgv8PfsgX372i6JJA1cWYE/Olr1reVLKlBZge/NV5IKVlbgr1tX\n9bdta7ccktQCA1+SClFW4K9fX/UNfEkFKivwly+vnqfzzDNtl0SSBq6swI+oavnW8CUVqKzAh6od\n3xq+pAKVF/jW8CUVqrzAt4YvqVBlBv62bZDZdkkkaaDKC/z16+HQIXj++bZLIkkDVV7ge/OVpEKV\nF/idm69sx5dUmPIC3xq+pEKVF/jW8CUVqrzAX7u2uuPWGr6kwpQX+CMj1YtQrOFLKkx5gQ/ebSup\nSGUG/rp18PTTbZdCkgaqzMDfsMHAl1SccgP/qad8vIKkopQb+Pv3+3gFSUUpN/ChquVLUiHKDPyN\nG6u+gS+pIGUGvjV8SQUy8CWpEGUG/ooVsGyZgS+pKI0FfkRcHxHbIuK+ptZxWjqXZkpSIZqs4f9X\n4LIGv//0bNgATz7ZdikkaWAaC/zM/DrwXFPff9o2brSGL6korbfhR8SVETERERPbt28f3Ipt0pFU\nmNYDPzOvy8xNmblpbGxscCvesAFefBF27x7cOiWpRa0Hfmu8NFNSYQx8A19SIZq8LPMm4C7gVRHx\neES8t6l1nZJzzqn6jz3WbjkkaUAWNvXFmfmupr67L172sqr/6KOtFkOSBqXcJp2lS6tXHRr4kgpR\nbuADjI8b+JKKUXbgn3uugS+pGGUH/vg4bN0Khw+3XRJJapyBf/Cgz9SRVAQDH2zWkVSEsgP/3HOr\nvoEvqQBlB/5LX1r1DXxJBSg78JcsqR6xYOBLKkDZgQ9Vs873v992KSSpcQb+D/4gPPRQ26WQpMYZ\n+BdeWF2W+cwzbZdEkhpl4F98cdXfsqXdckhSwwz8Cy+ECANf0tAz8FesgB/6IQNf0tAz8KFq1pmY\naLsUktQoAx+qwH/iCU/cShpqBj4cO3F7993tlkOSGmTgA7zmNbByJdx0U9slkaTGGPgAZ54JV1wB\nt9wC27a1XRpJaoSB3/H+98OBA3D99W2XRJIaYeB3nHcevOlNcM01XqIpaSgtbLsAc8rv/z685S3w\nxjdW/TVr4Lnn4NlnYceOqn/kCCxeXHWrVsH69bBuXdVfvx42boSzz676GzfC0qVtb5UkAQb+8c47\nDzZvhg99CL75TXj+eRgdrbpXvQrWroWRkarpZ//+av62bXDffdUlnQcOnPidq1efeBDoDHf669fD\nQneFpGaZMpOtXw833njyn8uEnTvhqaeqa/qffPJYvzP8wAPV/MkvTV+w4Nivg86BYN06GBurutHR\n44cXLerPtkoqioHfLxFVbX71ajj//OmXO3wYtm+f+oDw5JOwdSvcdVfVfJQ59XesXHniwWDt2mrd\nq1ZN31+8uJltlzQvGPiDNjICP/ADVXfRRdMvd+hQdf5gx47qANHpJo9v3VqdZH7uOdi3b+Z1n3nm\n9AeDFStm7s46q+ovX15tg6R5x8CfqxYurJp11q3r/TP79lXNSs8/X3Wd4en6TzxRnX/YuRN27z6x\nqWk6Z5554gFh+XJYtqya1+lmG59q2uLF1a8lSX1n4A+TJUuO/Xo4WZnVAWP37um7Xbumn/fMM7B3\n77HuxRer/skaGamCf+nSans63RlnHD8+VXeyy3SutprcLVrkrxgNJQNflYgqZJcuPblfFTPpHES6\nDwCTDwjTje/dW10JtW/f8d3u3VVT1r59U8+f7rzHyVqwYPoDQuegMNP8mZZZtKj6Bdfp9zp8qp9Z\n4O02qhj4ak73QWTt2ubXlwkHDx4L/6kOCN3dgQPV8gcOTN/NNv/Ageq7du3q7XvasGBB7wePkZFT\n707n8/387IIFVdc9PN202cZnWiZi3jU/Nhr4EXEZ8LvACPCZzPz3Ta5PhYs4Vos+66y2S3OizOpk\n/IEDVb/THTx46sP9/PzBg9V5nKm6/funn9fpDh2afZlO169fYm2L6M+BZd06+PrXGy9uY4EfESPA\np4CfAh4H7o6IL2bmA02tU5rTIqqatPdRVIF/qgeL2Q42ne8+cuRYN3l8qmltfebIkYFVUJqs4b8W\neDgzHwGIiM8DbwcMfKl0EceakzQwTZ7NORt4rGv88XqaJKkFrZ++j4grI2IiIia2b9/ednEkaWg1\nGfhPAC/pGj+nnnaczLwuMzdl5qaxsbEGiyNJZWsy8O8GXhkR50bEYuBy4IsNrk+SNIPGzphk5qGI\n+BfAHVSXZV6fmfc3tT5J0swaPUWembcDtze5DklSb1o/aStJGgwDX5IKETmHbnGOiO3A357ix0eB\nHX0sznzgNg+/0rYX3OaT9bLM7OkSxzkV+KcjIiYyc1Pb5Rgkt3n4lba94DY3ySYdSSqEgS9JhRim\nwL+u7QK0wG0efqVtL7jNjRmaNnxJ0syGqYYvSZqBgS9JhZj3gR8Rl0XEQxHxcER8tO3ynI6IeElE\n3BkRD0TE/RFxVT19TUR8OSK+W/dX19MjIq6tt/3eiLio67uuqJf/bkRc0dY29SIiRiLimxFxWz1+\nbkRsrrfr5vrhe0TEGfX4w/X88a7vuLqe/lBE/HQ7W9K7iFgVEbdExHci4sGIeN0w7+eI+FD9b/q+\niLgpIpYM436OiOsjYltE3Nc1rW/7NSIujohv15+5NuIkX6qbmfO2o3oo2/eAlwOLgW8B57ddrtPY\nng3ARfXwCuD/AecDvw18tJ7+UeA/1MNvBf4KCOASYHM9fQ3wSN1fXQ+vbnv7ZtjuDwM3ArfV438C\nXF4P/yHwgXr4nwN/WA9fDtxcD59f7/szgHPrfxMjbW/XLNv8WeB99fBiYNWw7meqFx99H1jatX//\n6TDuZ+AngIuA+7qm9W2/At+ol436s285qfK1/Qc6zT/u64A7usavBq5uu1x93L4/p3on8EPAhnra\nBuChevjTwLu6ln+onv8u4NNd049bbi51VO9J+CpwKXBb/Q95B7Bw8j6mevLq6+rhhfVyMXm/dy83\nFztgZR2AMWn6UO5njr39bk29324DfnpY9zMwPinw+7Jf63nf6Zp+3HK9dPO9SWdoX6NY/4y9ENgM\nrM/Mp+pZTwPr6+Hptn8+/V0+CfwGcKQeXwvszMxD9Xh32Y9uVz3/hXr5+bS9UNVOtwP/pW7K+kxE\nLGNI93NmPgH8DrAVeIpqv21h+PdzR7/269n18OTpPZvvgT+UImI58KfABzNzV/e8rA7tQ3EtbUS8\nDdiWmVvaLsuALaT62f8HmXkh8CLVT/2jhmw/rwbeTnWg2wgsAy5rtVAtaXu/zvfA7+k1ivNJRCyi\nCvsbMvML9eRnImJDPX8DsK2ePt32z5e/y+uBn42IR4HPUzXr/C6wKiI672roLvvR7arnrwSeZf5s\nb8fjwOOZubkev4XqADCs+/kfAt/PzO2ZeRD4AtW+H/b93NGv/fpEPTx5es/me+AP1WsU6zPufwQ8\nmJnXdM36ItA5U38FVdt+Z/ov1Wf7LwFeqH863gG8OSJW17WrN9fT5pTMvDozz8nMcap997XMfDdw\nJ/COerHJ29v5O7yjXj7r6ZfXV3ecC7yS6uTWnJSZTwOPRcSr6kn/AHiAId3PVE05l0TEmfW/8c72\nDvV+7tKX/VrP2xURl9R/x1/q+q7etH2Cow8nSN5KdTXL94CPtV2e09yWN1D93LsXuKfu3krVfvlV\n4LvAV4A19fIBfKre9m8Dm7q+61eAh+vul9veth62/U0cu0rn5VT/kR8G/jtwRj19ST3+cD3/5V2f\n/1j9d3iIk7xyoaXtvQCYqPf1n1FdjTG0+xn4LeA7wH3AH1NdaTN0+xm4ieo8xUGqX3Lv7ed+BTbV\nf8PvAb/HpBP/s3U+WkGSCjHfm3QkST0y8CWpEAa+JBXCwJekQhj4klQIA19DLyIOR8Q9XV3fnqoa\nEePdT0aU5rKFsy8izXt/l5kXtF0IqW3W8FWsiHg0In67fr74NyLiFfX08Yj4Wv2M8q9GxEvr6esj\n4taI+Fbd/Xj9VSMR8Z/r573/dUQsrZf/V1G92+DeiPh8S5spHWXgqwRLJzXpvLNr3guZ+feo7lr8\nZD3tPwGfzcwfA24Arq2nXwv8z8x8NdWzb+6vp78S+FRm/giwE/jH9fSPAhfW3/P+pjZO6pV32mro\nRcSezFw+xfRHgUsz85H6oXVPZ+baiNhB9fzyg/X0pzJzNCK2A+dk5v6u7xgHvpyZr6zHfxNYlJkf\nj4gvAXuoHp3wZ5m5p+FNlWZkDV+ly2mGT8b+ruHDHDs39jNUz0q5CLi768mQUisMfJXunV39u+rh\n/0319E6AdwP/qx7+KvABOPoe3pXTfWlELABekpl3Ar9J9YjfE35lSINkjUMlWBoR93SNfykzO5dm\nro6Ie6lq6e+qp/1LqrdR/TrVm6l+uZ5+FXBdRLyXqib/AaonI05lBPhcfVAI4NrM3Nm3LZJOgW34\nKlbdhr8pM3e0XRZpEGzSkaRCWMOXpEJYw5ekQhj4klQIA1+SCmHgS1IhDHxJKsT/B2bRjtpKUvay\nAAAAAElFTkSuQmCC\n",
       "text": [
        "<matplotlib.figure.Figure at 0x7fcd0ba19c50>"
       ]
      }
     ],
     "prompt_number": 6
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def predict(x_data, w, b):\n",
      "    xMat = np.mat(x_data)\n",
      "    w = np.mat(w)\n",
      "    print (xMat.shape)\n",
      "    print (w.shape)\n",
      "    print ((xMat*w).shape)\n",
      "    return [1 if x> 0.5 else 0 for x in sigmoid(xMat*w+b)]\n",
      "\n",
      "predictions = predict(x_train, w, b)\n",
      "print (\"Train Acc.: %.4f\" % (np.sum(predictions==y_train)/len(y_train)))\n",
      "predictions = predict(x_test, w, b)\n",
      "print (\"Test Acc.: %.4f\" % (np.sum(predictions==y_test)/len(y_test)))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "(80, 4)\n",
        "(4, 1)\n",
        "(80, 1)\n",
        "Train Acc.: 1.0000\n",
        "(20, 4)\n",
        "(4, 1)\n",
        "(20, 1)\n",
        "Test Acc.: 1.0000\n"
       ]
      }
     ],
     "prompt_number": 7
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}